import numpy as np
from rich import print

class Metrics:
    """Accumulate per-episode results for every agent and for the focal group.

    ``self.metrics`` maps each agent id from ``policies`` to a dict of lists
    (one entry appended per ``update`` call in which that agent appears), plus
    a reserved ``"focal"`` entry that stores, per episode, the summed reward,
    summed collisions and summed timeouts over all focal agents present in
    that episode's ``info``.
    """

    # (metric key, label used by ``report``) for each individual agent.
    _AGENT_METRICS = (
        ("reward", "Reward"),
        ("collision", "Collision"),
        ("route_completion", "Route Completion"),
        ("timeout", "Timeout"),
    )
    # The focal aggregate does not track route completion.
    _FOCAL_METRICS = (
        ("reward", "Reward"),
        ("collision", "Collision"),
        ("timeout", "Timeout"),
    )

    def __init__(self, policies):
        """Create empty metric buffers.

        Args:
            policies: mapping of agent id -> policy object; each policy must
                expose an ``is_focal`` attribute (truthy for focal agents).
        """
        self.metrics = {}
        self.focal_agents = []
        for agent_id, policy in policies.items():
            self.metrics[agent_id] = {key: [] for key, _ in self._AGENT_METRICS}
            if policy.is_focal:
                self.focal_agents.append(agent_id)
        self.metrics["focal"] = {key: [] for key, _ in self._FOCAL_METRICS}
        self.num_episodes = 0

    def update(self, info):
        """Record the results of one finished episode.

        Args:
            info: mapping of agent id -> result dict with keys ``"reward"``,
                ``"collision"``, ``"route_completion"`` and ``"timeout"``.
                Keys that are not known agent ids are ignored.
        """
        # make sure the results are updated per episode
        self.num_episodes += 1
        sum_focal_rewards = 0
        num_collisions = 0
        num_timeout = 0
        for agent_id, result in info.items():
            # Skip unknown keys and the reserved aggregate bucket: "focal" has
            # no "route_completion" list and must never be written per agent.
            if agent_id == "focal" or agent_id not in self.metrics:
                continue
            if agent_id in self.focal_agents:
                sum_focal_rewards += result["reward"]
                num_collisions += result["collision"]
                num_timeout += result["timeout"]
            agent_metrics = self.metrics[agent_id]
            for key, _ in self._AGENT_METRICS:
                agent_metrics[key].append(result[key])
        self.metrics["focal"]["reward"].append(sum_focal_rewards)
        self.metrics["focal"]["collision"].append(num_collisions)
        self.metrics["focal"]["timeout"].append(num_timeout)

    def check_early_stopping(self, early_stopping_iterations):
        """Return True when every focal agent has solved the recent episodes.

        For each focal agent, the sum of its rewards over the last
        ``early_stopping_iterations`` episodes must be at least
        ``early_stopping_iterations`` (an average reward >= 1 per episode;
        presumably a reward of 1 marks a solved episode -- confirm against the
        environment). Returns False while any focal agent has fewer recorded
        episodes than the window. With no focal agents this returns True.

        Args:
            early_stopping_iterations: window size; converted with ``int``.
                A non-positive window is trivially satisfied.
        """
        window = int(early_stopping_iterations)
        for agent_id in self.focal_agents:
            rewards = self.metrics[agent_id]["reward"]
            if len(rewards) < window:
                return False
            # Guard the slice: rewards[-0:] would select the whole history
            # instead of an empty window.
            recent = rewards[-window:] if window > 0 else []
            if np.sum(recent) < window:
                return False
        return True

    @staticmethod
    def _mean_std(values):
        """Return ``(mean, std)`` of ``values``, or ``(nan, nan)`` when empty."""
        if len(values) == 0:
            # np.mean/np.std also yield nan here, but emit a RuntimeWarning.
            return np.nan, np.nan
        return np.mean(values), np.std(values)

    def _report_group(self, metric, fields):
        """Print and return ``{key: (mean, std)}`` for each ``(key, label)``."""
        summary = {}
        for key, label in fields:
            mean, std = self._mean_std(metric[key])
            print(f"Average {label}: {mean}, {std}")
            summary[key] = (mean, std)
        return summary

    def report(self):
        """Print mean/std summaries and return them.

        Returns:
            dict mapping agent id -> {metric key: (mean, std)} for every agent
            with recorded data (agents without data are printed as such and
            omitted), plus a ``"focal"`` entry for the focal-group aggregate.
        """
        print("===============================================")
        print(f"Aggregating over {self.num_episodes} episodes")
        metric_report = {}
        for agent_id, metric in self.metrics.items():
            if agent_id == "focal":
                # report focal group separately
                continue
            if len(metric["reward"]) == 0:
                print(f"Agent {agent_id} has no data")
                continue
            print(f"Agent {agent_id}")
            metric_report[agent_id] = self._report_group(metric, self._AGENT_METRICS)
            print("-----------------------------------------------")
        print("Focal Agents")
        metric_report["focal"] = self._report_group(self.metrics["focal"], self._FOCAL_METRICS)
        print("===============================================")
        return metric_report